import os
from contextlib import nullcontext
from functools import wraps

# Directory that contains this module, derived from ``__file__``.
ROOT_DIR = os.path.split(__file__)[0]

# Root directory under which temporary directories and files are placed.
# ``None`` means it is auto-detected from the local system.
TEMP_DIR = None


def parametrize(argnames, argvalues, subtest=True):
    """A decorator that runs a test method once per set of argument values,
    optionally wrapping each run in a sub-test.

    - The decorated test function needs to take the parametrized values as
      extra positional arguments right after ``self``. Their parameter names
      may differ from ``argnames``. Any arguments the wrapper itself receives
      are passed on after them.
    - When decorators are stacked, each one prepends its values. The
      innermost decorator therefore supplies the first extra argument (see
      Example 3).
    - This loosely mimics ``pytest.mark.parametrize``, but is not fully
      compatible with it.

    Args:
        argnames (*): name(s) of arguments.
            - str: a comma-separated list of arg names
            - iterable: strs for arg names
        argvalues (iterable): value(s) of arguments. It is consumed once, when
            the decorator is created. A one-shot iterator (e.g. a generator)
            therefore works even when the decorated test is called several
            times, as happens with stacked decorators.
            - With multiple argnames, each value must be a sequence (e.g. a
              tuple) holding one item per arg name.
            - With a single argname, each value is passed as-is. The one
              exception is a 1-tuple ``(x,)``, which is unwrapped to ``x``.
        subtest (*): whether and how to wrap each iteration in ``self.subTest``
            - str: a comma-separated list of arg names to report in the subtest
            - iterable: strs for arg names to report in the subtest
            - True: report all args in the subtest
            - False: no subtest

    Returns:
        A decorator to apply to a test method.

    Raises:
        ValueError: when the decorated test is called and a value does not
            supply exactly one item per arg name.

    Example 1:

        @parametrize('i', [2, 3, 5, 7])
        def test_func(self, index):
            self.assertEqual(0, index)

        is roughly equivalent to:

        def test_func(self):
            for index in [2, 3, 5, 7]:
                with self.subTest(i=index):
                    self.assertEqual(0, index)

    Example 2:

        @parametrize('i, j, k', [(1, 10, 100), (2, 20, 200), (3, 30, 300)])
        def test_func(self, index1, index2, index3):
            self.assertEqual((0, 0, 0), (index1, index2, index3))

        or

        @parametrize(['i', 'j', 'k'], [(1, 10, 100), (2, 20, 200), (3, 30, 300)])
        def test_func(self, index1, index2, index3):
            self.assertEqual((0, 0, 0), (index1, index2, index3))

    Example 3:

        @parametrize('k', range(21, 25))
        @parametrize('j', range(11, 15))
        @parametrize('i', range(1, 5))
        def test_func(self, i, j, k):
            self.assertEqual((0, 10, 20), (i, j, k))

    Example 4 (tuple-valued single argument):

        @parametrize('shape', [(2, 3), (4, 5)])
        def test_func(self, shape):
            self.assertEqual(2, len(shape))
    """
    if isinstance(argnames, str):
        argnames = tuple(x.strip() for x in argnames.split(',') if x.strip())
    else:
        argnames = tuple(argnames)

    if isinstance(subtest, str):
        subtest = {x.strip() for x in subtest.split(',') if x.strip()}
    elif not isinstance(subtest, bool):
        subtest = set(subtest)

    # Materialize the values once. The wrapper below may run many times
    # (e.g. once per value of an outer stacked @parametrize), and a one-shot
    # iterator would be empty after its first pass.
    if len(argnames) == 1:
        # For a single name, wrap each value into a 1-tuple. An explicit
        # 1-tuple ``(x,)`` is kept as-is so it still means ``x``.
        argvalues = tuple(
            value if isinstance(value, tuple) and len(value) == 1 else (value,)
            for value in argvalues
        )
    else:
        argvalues = tuple(argvalues)

    def decorator(test_func):
        @wraps(test_func)
        def wrapper(self, *args, **kwargs):
            for argvalue in argvalues:
                if len(argnames) != len(argvalue):
                    raise ValueError(f'Unmatched arguments length: {argnames}, {argvalue}')
                extra_args = dict(zip(argnames, argvalue))
                if subtest is False:
                    context = nullcontext()
                elif subtest is True:
                    context = self.subTest(**extra_args)
                else:
                    # Report only the names selected via ``subtest``.
                    context = self.subTest(**{k: v for k, v in extra_args.items() if k in subtest})
                with context:
                    # Forward the values themselves rather than the dict's
                    # values, so duplicate names cannot drop arguments.
                    test_func(self, *argvalue, *args, **kwargs)
        return wrapper

    return decorator
